package pers.vic.sso.client.filter;


import org.apache.commons.lang3.StringUtils;
import pers.vic.boot.base.model.BaseResponse;
import pers.vic.sso.common.constant.Oauth2Constant;
import pers.vic.sso.common.model.RpcAccessToken;
import pers.vic.sso.common.model.SessionAccessToken;
import pers.vic.sso.common.util.SsoOauth2Util;
import pers.vic.sso.common.util.SsoSessionUtil;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.util.function.Consumer;

/**
 * Login filter for the SSO client.
 * <p>
 * Lets a request through when the local session holds an access token that is still valid or can be
 * refreshed. Otherwise it drives the OAuth2 authorization-code flow: it either redirects to the login
 * page, or exchanges an authorization code found in the request for an access token.
 *
 * @author Vic.xu
 * @date 2021-11-02 9:44
 */
public class LoginFilter extends BaseClientFilter {

    /**
     * Optional callback invoked with the new token after an authorization code has been successfully
     * exchanged for an access token. It is not invoked after a token refresh. {@code null} means no callback.
     */
    protected Consumer<RpcAccessToken> afterLogin;

    /**
     * Sets the callback invoked after a successful login (authorization-code exchange).
     *
     * @param afterLogin callback receiving the newly obtained token; may be {@code null}
     */
    public void setAfterLogin(Consumer<RpcAccessToken> afterLogin) {
        this.afterLogin = afterLogin;
    }


    /**
     * Decides whether the request may proceed.
     * <ol>
     * <li>Look up the access token stored in the local session.</li>
     * <li>If it exists and has not expired, allow the request.</li>
     * <li>If it has expired, use its refresh token to ask the server for a new token; allow the request if
     * that succeeds.</li>
     * <li>Otherwise (no token, or expired and not refreshable) read the authorization code from the request.</li>
     * <li>If there is no code, redirect to the login page.</li>
     * <li>If there is a code, exchange it for an access token (stored in the local session), then redirect
     * to the current URL with the code parameter removed.</li>
     * </ol>
     *
     * @param request  current request
     * @param response current response, used for redirects
     * @return {@code true} if the request may proceed; {@code false} if a redirect has been sent instead
     * @throws IOException if sending a redirect fails
     */
    @Override
    public boolean isAccessAllowed(HttpServletRequest request, HttpServletResponse response) throws IOException {
        SessionAccessToken sessionAccessToken = SsoSessionUtil.getAccessToken(request);
        // A token already in the local session is accepted if it is unexpired or was refreshed successfully
        if (sessionAccessToken != null && (!sessionAccessToken.isExpired()
                || refreshToken(sessionAccessToken.getRefreshToken(), request))) {
            return true;
        }
        String code = request.getParameter(Oauth2Constant.AUTH_CODE);
        if (StringUtils.isEmpty(code)) {
            redirectLogin(request, response);
            return false;
        }

        // Exchange the code for an access token, then redirect to the current URL without the code parameter.
        // NOTE(review): if the exchange fails (getAccessToken returns null), the redirected request arrives with
        // neither a session token nor a code, so it is sent to the login page again. Confirm this cannot loop
        // when the server keeps issuing codes that fail to exchange.
        getAccessToken(code, request);
        redirectLocalRemoveCode(request, response);
        return false;
    }

    /**
     * Redirects to the current URL with the authorization-code parameter removed from the query string.
     * <p>
     * Only query pairs whose key is exactly {@link Oauth2Constant#AUTH_CODE} are dropped. All other
     * parameters are kept in their original order. If the code was not in the query string at all
     * (e.g. it arrived as a form parameter), the current URL is used unchanged.
     *
     * @param request  current request
     * @param response response used to send the redirect
     * @throws IOException if the redirect cannot be sent
     */
    protected void redirectLocalRemoveCode(HttpServletRequest request, HttpServletResponse response) throws IOException {
        StringBuilder url = new StringBuilder().append(request.getRequestURL());
        String query = removeQueryParameter(request.getQueryString(), Oauth2Constant.AUTH_CODE);
        if (!query.isEmpty()) {
            url.append('?').append(query);
        }
        response.sendRedirect(url.toString());
    }

    /**
     * Removes every {@code name} and {@code name=value} pair from a raw query string.
     *
     * @param queryString raw (still URL-encoded) query string; may be {@code null}
     * @param name        parameter name to drop, compared exactly against the raw key
     * @return the remaining pairs joined by {@code &}, or an empty string if none remain
     */
    private static String removeQueryParameter(String queryString, String name) {
        if (queryString == null || queryString.isEmpty()) {
            return "";
        }
        StringBuilder kept = new StringBuilder();
        for (String pair : queryString.split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            int eq = pair.indexOf('=');
            String key = eq >= 0 ? pair.substring(0, eq) : pair;
            if (key.equals(name)) {
                continue;
            }
            if (kept.length() > 0) {
                kept.append('&');
            }
            kept.append(pair);
        }
        return kept.toString();
    }

    /**
     * Exchanges an authorization code for an access token.
     * <p>
     * On success the token is stored in the local session, and {@link #afterLogin} is invoked if it is set.
     *
     * @param code    authorization code taken from the request
     * @param request current request, whose session receives the token
     * @return the obtained token, or {@code null} if the server reported a failure
     */
    private RpcAccessToken getAccessToken(String code, HttpServletRequest request) {
        BaseResponse<RpcAccessToken> result = SsoOauth2Util.getAccessToken(getServerUrl(), getAppId(),
                getAppSecret(), code);
        if (!result.isSuccess()) {
            logger.error("getAccessToken has error, message:{}", result.getMsg());
            return null;
        }
        RpcAccessToken accessToken = result.getData();
        // Store in the local session
        setAccessTokenInSession(accessToken, request);
        // Post-login callback
        if (afterLogin != null) {
            afterLogin.accept(accessToken);
        }
        return accessToken;
    }


    /**
     * Uses the current request URL (including its query string) as the redirect target.
     *
     * @param request current request
     * @return the full current URL
     */
    @Override
    protected String getRedirectUrl(HttpServletRequest request) {
        return getCurrentUrl(request);
    }

    /**
     * Builds the current request URL, with {@code ?query} appended when a query string is present.
     *
     * @param request current request
     * @return the full current URL
     */
    private String getCurrentUrl(HttpServletRequest request) {
        return new StringBuilder().append(request.getRequestURL())
                .append(request.getQueryString() == null ? "" : "?" + request.getQueryString()).toString();
    }


    /**
     * Calls the server with a refresh token to obtain a new access token, and stores it in the local session.
     *
     * @param refreshToken refresh token of the expired session token
     * @param request      current request, whose session receives the new token
     * @return {@code true} if a new token was obtained and stored; {@code false} otherwise
     */
    protected boolean refreshToken(String refreshToken, HttpServletRequest request) {
        // The token value is deliberately not logged: it is a credential.
        logger.info("start refreshToken");
        BaseResponse<RpcAccessToken> result = SsoOauth2Util.refreshToken(getServerUrl(), getAppId(), refreshToken);
        if (!result.isSuccess()) {
            logger.error("refreshToken has error, message:{}", result.getMsg());
            return false;
        }
        return setAccessTokenInSession(result.getData(), request);
    }

    /**
     * Stores the token in the local session and records the session/access-token mapping.
     *
     * @param rpcAccessToken token returned by the server; may be {@code null}
     * @param request        current request
     * @return {@code false} if the token is {@code null}; {@code true} once it has been stored
     */
    private boolean setAccessTokenInSession(RpcAccessToken rpcAccessToken, HttpServletRequest request) {
        if (rpcAccessToken == null) {
            return false;
        }
        // Store the access token in the local session
        SsoSessionUtil.setAccessToken(request, rpcAccessToken);
        // Record the mapping between the local session and the access token
        recordSession(request, rpcAccessToken.getAccessToken());
        return true;
    }

    /**
     * Replaces the mapping stored for this request's session with one keyed by the given access token.
     *
     * @param request     current request; its session is created if it does not exist yet
     * @param accessToken access token to associate with the session
     */
    private void recordSession(final HttpServletRequest request, String accessToken) {
        final HttpSession session = request.getSession();
        // Drop whatever was previously stored for this session id, then map the new token to the session
        getSessionMappingStorage().removeBySessionById(session.getId());
        getSessionMappingStorage().addSessionById(accessToken, session);
    }


}
